import cv2
import  xml.etree.ElementTree as ET 
import numpy as np
import os
import json
import shutil
import base64
'''
该脚本实现将xml类型标签(或者yolo格式标签)转为json格式标签
需要的数据：原始图像 原始xml标签（原始txt标签）

'''

# 解析数据集，输入单张图片路径，图片路径不能出现中文，因为是cv2读取的。和对应xml文件的路径
# 返回图片 该图所有的目标框[[x1,y1,x2,y2],....]  每个框的类别[label1, label2, label3,.....]  注意是label而不是索引
def parse_img_label(img_path, xml_path):  # 绝对路径
    img = cv2.imread(img_path)
    tree = ET.parse(xml_path) 
    root = tree.getroot()
    objs = root.findall('object')
    bboxes = []  # 坐标框
    h ,w = img.shape[0], img.shape[1]
    #gt_labels = []  # 标签名
    for obj in objs: # 遍历所有的目标
        label = obj[0].text  # <name>这个tag的值，即标签
        label = label.strip(' ')
        box = [int(obj[4][i].text) for i in range(4)]
        box.append(label)  # box的元素 x1 y1 x2 y2 类别
        bboxes.append(box)
    return img, bboxes

# 该函数用于将yolo的标签转回xml需要的标签。。即将归一化后的坐标转为原始的像素坐标
def convert_yolo_xml(box,img):  # 
    x,y,w,h = box[0], box[1], box[2], box[3]
    # 求出原始的x1 x2 y1 y2
    x2 = (2*x + w)*img.shape[1] /2
    x1 = x2 - w*img.shape[1]

    y2 = (2*y+h)*img.shape[0] /2
    y1 = y2 - h* img.shape[0]
    new_box = [x1,y1, x2, y2]
    new_box = list(map(int,new_box))
    return new_box

# 该函数用于解析yolo格式的数据集，即txt格式的标注 返回图像 边框坐标 真实标签名（不是索引，因此需要预先定义标签）
def parse_img_txt(img_path, txt_path):
    name_label = ['class0','class1','class2']  # 需要自己预先定义,它的顺序要和实际yolo格式的标签中0 1 2 3的标签对应 yolo标签的类别是索引 而不是名字
    img = cv2.imread(img_path)
    f = open(txt_path)
    bboxes = []
    for line in f.readlines():
        line = line.split(" ")
        if len(line) == 5:
            obj_label = name_label[int(line[0])] # 将类别索引转成其名字
            x = float(line[1])
            y = float(line[2])
            w = float(line[3])
            h = float(line[4])
            box = convert_yolo_xml([x,y,w,h], img)
            box.append(obj_label)
            bboxes.append(box)
    return img, bboxes



# 制作labelme格式的标签
# 参数说明 img_name： 图像文件名称 
# txt_name: 标签文件的绝对路径，注意是绝对路径
# prefix： 图像文件的上级目录名。即形如/home/xjzh/data/ 而img_name是其下的文件名，如00001.jpg
# prefix+img_name即为图像的绝对路径。不该路径能出现中文，否则cv2读取会有问题
# 
def get_json(img_name, txt_name, prefix, yolo=False):
    # 图片名 标签名 前缀
    label_dict = {}  # json字典，依次填充它的value 
    label_dict["imagePath"] = prefix + img_name  # 图片路径
    label_dict["fillColor"] = [255,0,0,128]  # 目标区域的填充颜色 RGBA
    label_dict["lineColor"] = [0,255,0,128]  # 线条颜色
    label_dict["flag"] = {}
    label_dict["version"] = "3.16.7"  # 版本号，随便
    with open(prefix + img_name,"rb") as f:
        img_data = f.read()
        base64_data = base64.b64encode(img_data)
        base64_str = str(base64_data, 'utf-8')
        label_dict["imageData"] = base64_str  # labelme的json文件存放了图像的base64编码。这样如果图像路径有问题仍然能够打开文件

    img, gt_box = parse_img_label(prefix + img_name, txt_name) if not yolo else parse_img_txt(prefix + img_name, txt_name)  # 读取真实数据
    
    label_dict["imageHeight"] = img.shape[0]  # 高度
    label_dict["imageWidth"] = img.shape[1]

    shape_list = [] # 存放标注信息的列表，它是 shapes这个键的值。里面是一个列表，每个元素又是一个字典，字典内容是该标注的类型 颜色 坐标点等等
    #label_dict["shapes"] = [] # 列表，每个元素是字典。
    # box的元素 x1 y1 x2 y2 类别
    for box in gt_box:
        shape_dict = {}  # 表示一个目标的字典
        shape_dict["shape_type"] = "rectangle"  # 因为xml或yolo格式标签是矩形框标注，因此是rectangle
        shape_dict["fill_color"] = None  #该类型的填充颜色 
        shape_dict["line_color"] = None  # 线条颜色 可以设置，或者根据标签名自己预先设定labe_color_dict
        shape_dict["flags"] = {}
        shape_dict["label"] = box[-1] # 标签名  
        shape_dict["points"] = [[box[0],box[1]], [box[2], box[3]]] 
        # 通常contours是长度为1的列表，如果有分块，可能就有多个  # [[x1,y1], [x2,y2]...]的列表
        shape_list.append(shape_dict)
    
    label_dict["shapes"] = shape_list  #
    return label_dict

imgs_path = "/home/xjzh/fgd/JPEGImages/"  # 图像路径
xmls_path ="/home/xjzh/fgd/Annotations/" # xml文件路径

img_path = os.listdir(imgs_path)
out_json = '/home/xjzh/DATA/JSON_data/'  # 保存的json文件路径

for nums, path in enumerate(img_path):
    if nums %200==0:
        print(f"processed {nums} images")
    xml_path = xmls_path + path.replace('jpg','xml')  # xml文件的绝对路径
    label_dict = get_json(path, xml_path,prefix=imgs_path)  # 
    with open(out_json + path.replace("jpg","json"),'w') as f: # 写入一个json文件
        f.write(json.dumps(label_dict, ensure_ascii=False, indent=4, separators=(',', ':')))